{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 这是一个业余项目，跑一个 MNIST 上的十分类 Logistic 分类器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(58000, 784)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAACRCAYAAADTnUPWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFDZJREFUeJzt3Xl0VFWeB/DvL5WwKntIA0EWE2RzoWURtVtbZEZUwHan1cZxQQVFOVEEnWmmT/exdZyxB1xg8Kio06I2qCC0okNjqw2N4AoYCIishlVkC0uWO3+kfLd+RV5SqarUcuv7OadO7q3fq/du8qvcvNy67z4xxoCIiNJfVrIbQERE8cEOnYjIEezQiYgcwQ6diMgR7NCJiBzBDp2IyBFOd+gisklELo5wWyMiBVEeJ+rXUv0xr25iXmPndIeeykSkjYjsFpGPk90Wip2I/KeIrBeRgyKyVkR+new2UezSLa/ZyW5ABnsMQDH4R9UVhwEMB1ACYACAd0VkgzFmaXKbRTFKq7xmTGciIgNFZJmI/CAipSLylIg0CtvsUhHZKCJ7RORxEckKef0tIlIsIvtEZJGIdImhLYMB9AXwQrT7oGqpkldjzBRjzFpjTJUxZjmAjwAMjuFby2jMa3QypkMHUAlgAoB2qE7IEABjw7b5JYD+AH4KYCSAWwBARK4A8BCAKwHkojqps2s6iIj8SkS+8muEiAQAPA3gbgBcdyF2KZHXsG2bovpsbk09vxeymNdoGGOcfQDYBOBin9h9AN4MqRsAl4TUxwJYHCy/A+DWkFgWgDIAXUJeWxBhmyYAmB4s3wzg42T/nNLtkYp5DWvDiwDeBSDJ/lml04N5jf2RMWfoItJDRBaIyA4ROQDgEVT/9Q+1NaS8GUDHYLkLgKnBf/9+APA9AAHQqZ5t6AhgPICHo/ke6ESpkNew9jyO6uG0a02wF6D6Y16jkzEdOoDpANYCKDTGtED1v2QStk3nkPIpAL4LlrcCuMMY0yrk0dTU/4ORgQA6APhaRHYAmApgYPBNG6jvN0QAUiOvAAAR+S2AYQD+yRhzIJp9kId5jUImdegnAzgA4JCI9ARwVw3bPCAirUWkM4B7AbwWfH4GgMki0gcARKSliFwTRRveAdAVwFnBx28AfA7gLGNMZRT7o9TIK0RkMoBfARhqjNkbzT5IYV6jkEkd+v2oTsxBAM/CJj/UPACfAvgCwEIAzwGAMeZNVE8zfDX4799qVP/FPoGI3CAiNX5oYow5ZozZ8eMDwH4A5cEyRSfpeQ16BNVnietF5FDw8VB03xKBeY2KpPBwEBER1UMmnaETETmNHToRkSPYoRMROSKmDl1ELhGRdSKyQUQmxatRlFzMq7uYW7dF/aFocN50CYChALYBWAFglDHma7/XNJLGpgmaR3U8ip+jOIzj5lj4nF4AzGs6qy2vQP1zy7ymjoPYt8cYk1vXdrGstjgQwAZjzEYAEJFXUb2egu8vfhM0xyAZEsMhKR6Wm8W1hZnXNFVHXoF65pZ5TR3/Z+ZsjmS7WIZcOkFfersNNVxaKyJjRGSliKwsx7EYDkcJwry6q87cMq/pLZYOvaZ/7U4YvzHGzDTG9DfG9M9B4xgORwnCvLqrztwyr+ktlg59G/RaCvmwaylQ+mJe3cXcOi6WDn0FgEIR6RZceP56APPj0yxKIubVXcyt46L+UNQYUyEidwNYBCAA4HljTOou/E4RYV7dxdy6L6Z7ihpj/gLgL3FqC6UI5tVdzK3beKUoEZEj2KETETmCHToRkSPYoRMROYIdOhGRI9ihExE5gh06EZEj2KETETmCHToRkSPYoRMROSKmS/+JUt3O8ed6ZXPRPhUbXbDcK49p5Xv/DgDAhO32Rg/bL2+mYpW7d8fSREqgQNs2qi4tW3jlLVd1VLGj7ezKwgW//VLFqsrKGqB1seMZOhGRI9ihExE5gkMuDSjQq1DV
ixa84ZUfu+lGFZOl+l86ilx2vr2LWtkL+i29oveTXrm4vFzFJm68yiu/t6uXis0seE3VZ+R/ZMt/66Ji83u3rWeLqSFl9e3plddPbqpit5y+VNWL2i6KaJ+98u5U9cKbP42ydQ2LZ+hERI5gh05E5Ah26EREjsjYMXQ5u4+qZ5Ud98qVxevjc5BnDqvqD5XNvXL2D0dUrDI+R8xI/RZs8cpXt1ypYj3m3eOVe/9ui4qZ0u2++xwz4C5Vnz5nule+reVGFfvjf13mlU8t+kcELaZYyYDTvfKGCQEV++D8p7xybqCximWFncMuLGvtlTcea69i41qv88ov//xZFfvdgNFe2axYFWmzGxzP0ImIHMEOnYjIERk15PL9LYO98utTHlexYbMmeuUuU6Ifcjk2bIBXnlc4TcV+OmeCVy74mv+aR+vQteeo+pTcp73yOZ/dpGI9xn7ilSvqcYzwf6OHvHW/V15/9TMq9tjwV7zyzKLu9TgK1SaQm+uVS6Z2UrG3z7U56J6TE/bKxvDzwoHOqv7WVed75arGej/jFtghl/6N9aDokTw7HbKJ79ESj2foRESOYIdOROQIduhERI7IqDH0Rtfu9MojPhujYl2mLA3fPCqbR4hXbiaNVKz7nKNxOUamqwwbMn3pgB1fDcxtmMvwT/1zSO6u1rHc7AP2+O308Sv37G2Q9mSC7TfapTPWXDA1LBo+bl6z/w0fM7/iXFWvXFfilaWfnsqcjniGTkTkiDo7dBF5XkR2icjqkOfaiMj7IrI++LV1bfug1MO8uou5zVyRDLnMAvAUgJdCnpsEYLEx5lERmRSsPxj/5sXXH3rY1Q5Hvzemli2jd9JPDnnlLEgtWybdLKRpXlu/pacUzn27h40dWNYgxwwc9Z/0eF7jKq+8+fbTVCz/D/EZyqunWUjT3IbqNGJTRNvNOfQTVX+ixN6MJG+iUbHKdf5Tkved3sI3li7qPEM3xnwI4Puwp0cCeDFYfhHAFXFuFzUw5tVdzG3minYMPc8YUwoAwa/t/TYUkTEislJEVpbjWJSHowRhXt0VUW6Z1/TW4B+KGmNmGmP6G2P659RyBRelF+bVTcxreot22uJOEelgjCkVkQ4AdsWzUfESKOim6rlZDT+e+d7ZM71yadgSioEj9o45emQvZaRFXqsOH657o3hbZcden/xBX95/Tyu7+mJZd31XpBSSFrlVbrd/UHqPu0eFOr9vf7mar9mhYu0226mI9VnFtCwvpT/ziki0Z+jzAfy4fuRoAPPi0xxKMubVXcxtBohk2uJsAMsAnCYi20TkVgCPAhgqIusBDA3WKY0wr+5ibjNXnUMuxphRPqEhPs+njMoN36r6quMdvLI009PQspo188pVZWVRH7N9wO7n93v66uCXdvW20OPFesxopHNek8Ecsx8QHqpMpfX1TuRKbkN/fwsmfOu7XX1W0axN+YCDcdpT8vBKUSIiR7BDJyJyBDt0IiJHZNRqiw/+9TqvXDJ8uoqd9+frvXLbyfqms1VfrY3qeKNarVD1bz4aZst/7K1iJ73OOxilstDPPNpl7/bf7lDAN0aJseU3dkXFimZhE4TDZyaGhK8s9F824u5tF6p603c/q2kXScczdCIiR7BDJyJyREYNufT+981eeUg3fZeCZWe95pUD7+q/cwM/v8Yrlx3TN604L3+jqgfEvvbU7KYq9sl7dhpjl9eTsgofRcn0tleH3t7yY9/tTlkU+bWJ2fn2xhz7z8lXsR2D7Puo4DU9nc6sXI1ME2ihV0I8OtDe/CJn8k4V+6rnk777yRE9JFZu/PO15IgdZts25hQVMxXF/o1NIp6hExE5gh06EZEj2KETETkio8bQK3bYsbZml+lvvf8dd3vl5sP16m2X5a/x3WffpttUvdLYu9ecO3mcinWdvdIrp9JUJzpxKQYUdlHV7RdEdjebof/xoaq/dMtAr3xjTz2N9YymS7zyZc0OqdimCrsUxIjud6hY/lURNSXtSGO9
XO/xC073yhOeeVnFftF0sVfeWanXbV9yxN5d7zclI1Vsdp9Zqt4x23+J4CZZduXMjde2UrHu6+zyD1VHU+fm7zxDJyJyBDt0IiJHsEMnInJERo2hhzIVetHN9k+HzAt/Wm+7FHrueag3x9+o6lc8+JRXbvfhdhWrKD9ez1ZSTbJOPlnVpbNdFnnX4LYqtneAnWc8apD/8grtG32n6ve08p9rXpv72qxS9dPOKvXddsLbv/bKUxfqOx012mnvypS/2v8znHSX1cSORe+9rp+KffTINN/X9Zlt72CUv0TPJW+80H5W0baD/mxi9qKzVb2orf+c/kGNbU6+ulm3ZfDW8V4576UvVSzRS2GH4hk6EZEj2KETETkiY4dc4uVg16q6N6J6Cx9WWft4L698/8/fUbE7W/4tqmN8U3HEK28q19PSjhg9PNZU/Ifden34L175lJn60vLAks/CN/cUwH8IyNV3VfjUxLVPnGHLI/2HWEauu0LVezxul9yo3Knvd53d2S6jcOb8LSr2QNuvVX1/lc3zoLlFKtahp93v4tNfU7Fl/2bbet2oy1VszzQ73bLJXv+bhgc+8H9vRItn6EREjmCHTkTkCHboRESO4Bh6PYVfIj7qor+r+hP77LKeVXv3JaRNLmq6sImqbzh1hlfeV3VExS5da+9EtX5rnop1XGDf4oGjesGF5iV7vXJlyTcqtqlYj8ve2sIu8fDqoVwVKxhrx2kr9zHn4STb5mDdf5+pYmtH2DnC2yr0Jfwj/meiV+76vM5PRci4efnFeipi38c+98pT2n+qYi8c0Es6vPzwcK9c8Ib+TCPQzk6BvXDoPSp2+Lr9XvnNfs+qWP40/+UEFhy2+5zZo7vvdtHiGToRkSPYoRMROYJDLvWUlaf/3Z6S+5Gqn/bWWK9ceHB5QtrkojcK3lf11w/ZFfRmjrlNxUKnfxVCr35Zm6ocOxWxZMZAFbu0+ROq/o9jdqjthTv1Cn6BffGffuaSrQ/Yn+3aEVNV7LuQYZZrHn1Axbq+Zacmfn9RNxUzN9pprXP66n3mBuyQR59X9VBJj5l7VL3ZOv/f0co9dkiuxey9KtZiti1fPXaiiuVdvRm+ikKnx8b/CmCeoRMROaLODl1EOovIEhEpFpE1InJv8Pk2IvK+iKwPfm1d174odTCvbmJeM1skZ+gVAIqMMb0AnANgnIj0BjAJwGJjTCGAxcE6pQ/m1U3MawarcwzdGFMKoDRYPigixQA6ARgJ4MLgZi8C+ADAgw3SyhSy5apOtca7za2oNZ4qUj2voXd+AoDiI/bnnv13vUJepHd/ymreXNWr5tuT1A09Z6jYvipR9UlFd3nlZktS97ORVMzr9Nuf8Y01CfkxD79T3+2p03g7BXR0i7drOYKeJtjnFbsSYsFkfZeoyor4/362f2apqhv/bxfA9tqCMavXGLqIdAXQD8ByAHnBN8+Pb6L2Pq8ZIyIrRWRlOY7VtAklGfPqJuY180TcoYvISQDmArjPGHMg0tcZY2YaY/obY/rnwH/CPSUH8+om5jUzRTRtUURyUP3m+JMx5o3g0ztFpIMxplREOgDY5b8Hd5SdeaTWePZfP601nkpSOa/PHchX9X9tZ4dZ+r4yWsU6trZX7X27pqOKnbzJnrPcdttCFRvT6gOvXLRjsIqtLjpD1Zt9kLrDLOFSLa8fHurplQc11jcAaRMyxfChdl/47uPytVeq+pZl9v3Rfc5+FStYY38Hw29k47pIZrkIgOcAFBtjQifnzgfw42/WaADz4t88aijMq5uY18wWyRn6eQBuArBKRH78E/oQgEcBvC4itwLYAuCahmkiNRDm1U3MawaLZJbLxwDEJzwkvs2hRGFe3cS8ZjZe+h+Bqp/Zm9eWXPScio3dfn7Y1rWPsVNk5vbSkzAenWZXxfvkl/qy/BwJGTnsCV+XrNI39H7l98O8covZYSvtgZfzx8vSX9jPNQbdcJGK7T/T3jEoe3eOivWYYaf4Ze/QQ/5d
j271yq7e3SkavPSfiMgR7NCJiBzBIZcIHG1nV+WrCrsu8b1VfVS9B1YmpE2ZpnC8nTZ4w/jzotpHC3wT9kx4nRpC5d7vvXLeNH1VZV74xiEya8JhfPAMnYjIEezQiYgcwQ6diMgRHEOPwLZ/tuPmnx/Xk6R6Tdyo6pUJaRER0Yl4hk5E5Ah26EREjuCQSz3N2vMzVQ+dkkVElEw8QycicgQ7dCIiR7BDJyJyBMfQI9Djzk+8Mi8WJ6JUxTN0IiJHsEMnInIEO3QiIkewQycicgQ7dCIiR7BDJyJyhBhj6t4qXgcT2Q1gM4B2APYk7MC1y8S2dDHG5MZrZ8xrnZjX+MnUtkSU24R26N5BRVYaY/on/MA1YFviJ5Xaz7bETyq1n22pHYdciIgcwQ6diMgRyerQZybpuDVhW+InldrPtsRPKrWfbalFUsbQiYgo/jjkQkTkCHboRESOSGiHLiKXiMg6EdkgIpMSeezg8Z8XkV0isjrkuTYi8r6IrA9+bZ2AdnQWkSUiUiwia0Tk3mS1JR6YV9UWZ3LLvKq2pEVeE9ahi0gAwNMAhgHoDWCUiPRO1PGDZgG4JOy5SQAWG2MKASwO1htaBYAiY0wvAOcAGBf8WSSjLTFhXk/gRG6Z1xOkR16NMQl5ABgMYFFIfTKAyYk6fshxuwJYHVJfB6BDsNwBwLoktGkegKGp0BbmlbllXtM3r4kccukEYGtIfVvwuWTLM8aUAkDwa/tEHlxEugLoB2B5stsSJebVR5rnlnn1kcp5TWSHLjU8l9FzJkXkJABzAdxnjDmQ7PZEiXmtgQO5ZV5rkOp5TWSHvg1A55B6PoDvEnh8PztFpAMABL/uSsRBRSQH1W+MPxlj3khmW2LEvIZxJLfMa5h0yGsiO/QVAApFpJuINAJwPYD5CTy+n/kARgfLo1E9NtagREQAPAeg2BjzRDLbEgfMawiHcsu8hkibvCb4g4RLAZQA+AbAw0n4IGM2gFIA5ag+A7kVQFtUfzq9Pvi1TQLacT6q/339CsAXwcelyWgL88rcMq/u5JWX/hMROYJXihIROYIdOhGRI9ihExE5gh06EZEj2KETETmCHToRkSPYoRMROeL/Acsn/QsXkunNAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load and inspect the data\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.io as scio\n",
    "\n",
    "# Fix the RNG seeds so mini-batch sampling and weight initialization are reproducible.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# The first N_TRAIN images of the official training set are used for training, the rest for validation.\n",
    "N_TRAIN = 58000\n",
    "\n",
    "# Load every .mat file exactly once.\n",
    "all_train_data = scio.loadmat('hw4_lr/train_imgs.mat')['train_imgs']\n",
    "all_train_labels = scio.loadmat('hw4_lr/train_labels.mat')['train_labels']\n",
    "test_data = scio.loadmat('hw4_lr/test_imgs.mat')['test_imgs']\n",
    "test_labels = scio.loadmat('hw4_lr/test_labels.mat')['test_labels']\n",
    "\n",
    "# Images are stored one per row, so they are split along axis 0.\n",
    "train_data = all_train_data[:N_TRAIN]\n",
    "val_data = all_train_data[N_TRAIN:]\n",
    "# Labels are a 1 x M row vector (read as labels[0][i] below), so they must be split along axis 1;\n",
    "# slicing axis 0 would leave train_labels unsplit and make val_labels empty.\n",
    "train_labels = all_train_labels[:, :N_TRAIN]\n",
    "val_labels = all_train_labels[:, N_TRAIN:]\n",
    "\n",
    "print(np.shape(train_data))\n",
    "# Each image is a flattened row of 28 * 28 = 784 pixels; reshape it back to 28 x 28 to view it.\n",
    "plt.figure()\n",
    "for pos, idx in enumerate((26, 16, 5), start=1):\n",
    "    plt.subplot(1, 3, pos)\n",
    "    plt.imshow(train_data[idx].reshape((28, 28)))\n",
    "    plt.title('label: %s' % train_labels[0][idx])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAACRCAYAAADTnUPWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHNVJREFUeJztnXuMXVd1xr9179x5ejx+xE7sxInzcGICVElx0qRUAhUShdAqtCotULWpiJRSoCWIViQgtapUoVRUqEigVqGhCRLi0QJKRAmPWlRtVB4xJS9iJ46dlxM/4ldsj+d57+4fc5m91rpz9tx75859nPl+0mjOmX3O2fucdc6es7+z9loSQgAhhJDep9DpBhBCCGkN7NAJISQnsEMnhJCcwA6dEEJyAjt0QgjJCezQCSEkJ+S6QxeR50Xk7XVuG0TksibraXpf0ji0az6hXZdOrjv0bkZE1onIqyLycKfbQpaOiPyDiOwVkdMiskdE/rjTbSJLp9fs2tfpBqxg/h7AbvCfal4YB/DbAJ4BcA2A74rIsyGE/+1ss8gS6Sm7rpjORESuFZEfichJETkoIp8TkX632c0isl9EjorIp0WkoPZ/v4jsFpETIvI9EbloCW25HsAbAPxrs8cgc3SLXUMIfxNC2BNCqIQQfgLgfwBcv4RTW9HQrs2xYjp0AGUAHwVwDuYM8jYAH3Tb/A6AHQB+FcAtAN4PACLyLgCfAPC7ADZgzqhfWagSEXmfiDye1QgRKQL4PIAPA2DchaXTFXZ12w5h7m3uFw2eC4nQrs0QQsjtD4DnAbw9o+wOAN9S6wHATWr9gwB2VpcfAnCbKisAOAvgIrXvZXW26aMA/qm6/CcAHu70deq1n260q2vD/QC+C0A6fa166Yd2XfrPinlDF5HLReTbInJIRE4B+BTm/vtrXlLLLwDYXF2+CMBnq8O/kwCOAxAA5zfYhs0A/gLAJ5s5B1JLN9jVtefTmJPTfj9UewHSOLRrc6yYDh3APwHYA2BbCGE15oZk4rbZopYvBPBKdfklAH8aQlijfoZC4x9GrgWwCcBTInIIwGcBXFu9aYuNnhAB0B12BQCIyN8CeAeAG0MIp5o5BpmHdm2CldShjwI4BeCMiGwH8GcLbPNXIrJWRLYA+AiAr1X//s8A7hKR1wOAiIyJyLubaMNDALYCuKr689cAfg7gqhBCuYnjke6wK0TkLgDvA3BDCOFYM8cgBtq1CVZSh/6XmDPMaQBfQDS+5gEAPwPwKID/AHAvAIQQvoU5N8OvVod/T2LuP3YNIvKHIrLgR5MQwlQI4dAvfwC8BmCmukyao+N2rfIpzL0l7hWRM9WfTzR3SgS0a1NIF8tBhBBCGmAlvaETQkiuYYdOCCE5gR06IYTkhCV16CJyk4g8LSLPisidrWoU6Sy0a36hbfNN0x9Fq37TzwC4AcABAI8AeG8I4amsfUoDI2FgZF1T9ZHWMTV+HDNT496nF0Bzdu0vDoWh0tiytJXUz8TMa5guTyxoV6Bx2/b3DYeh/jXL0lbSGKcmDh4NIWxYbLulRFu8FsCzIYT9ACAiX8VcPIXMB39gZB3eeOMdS6iStIInvv+PqeKG7TpUGsP1F3Z1VFEAgFSa9+gKhcx+smv40YtfWmyThmw71L8G111+W0vbSJrj+4/93Qv1bLcUyeV82Km3B7DA1FoRuV1EdonIrpmp8SVUR9pEw3adLk+0rXFkSSxqW2PXWT6vvcZSOvSFXllqXoFCCPeEEHaEEHaUBkaWUB1pEw3btb841IZmkRawqG2NXfv4vPYaS5FcDsDGUrgAMZZCd7AMk6bEHTIsx0hcOjq8b5tdlyKBtJt2t3WZJJ7uf2abodKi4+TA528pp/AIgG0icnE18Px7ADzYmmaRDkK75hfaNuc0/YYeQpgVkQ8D+B6AIoAvhhC6N/A7qQvaNb/QtvlnSTlFQwjfAfCdFrWl
2UZkFnl5RKuFNWUZ281tGzLLaupUw7/gxz9qFB28rGLKfAN0Wary1gzTW2nXtkgV9UprnYhblLJJoix13ZYix3TFM5tFQjqRlO0asWtq20oD11XZruZZ1rRZxsmBakQIIQRgh04IIbmBHTohhOSEJWnoHUPpYI3o5FqXFJ8fKCy83WJlKS3e691G+yyEzLLgktHZ/RJ1eH2ws+6PraVZDTWpmYb6tlsMfZ2b1bcbsJW+B3thBquhEZ08JOyT+jaT2q8RtE28fdR1l5rvYepZTunyy/A6zTd0QgjJCezQCSEkJ3Sv5JJyR1TDtqSsMuvKyrGsUHbSyawqm7HjQilX1HauLDXyK4pbj/8/KyX7v7TSJ5llWoIJzmJGqvFyjNGK2jc0b9pVsV7pJDX8Tg3bK9Z2Qa830mYnc0hqaF4sZpcVEvuljqk364XZtuZ5bUBWKVcyy/Qz6e2a2i95j9XYTj1QBf9MZpfp/RqSY1rwes03dEIIyQns0AkhJCewQyeEkJzQvRq6QirZ64VZr62FzDKtjRem7UEL09GPUaZm7DGnohgv07asRr/TGlmf8z8c6I+7DdpLXxnoU8t2v3J//L8bnC9kxRwmW5NblqiQGWhdP6nxNuJuWK8uWrb+qOVN58Tl1f22rJT9TcN/Gymdjnbv2/OirV6fY9G9I6n7Q2q0VmVnv1+KbndH9Y9E4juG+XbhymS2nFkGVRZm3ccyve7uB3PNAUipFDfdstGUzYzG+2XwmUOubcoG7jkP+qNXSl+3JVZTb/JVm2/ohBCSE9ihE0JITugeycW7JdU949PJKtr90MsqU3H4VZi0w7TC5HQ85tlJ2zS1Xpm0ZTVDOjUclvPPM0W7/zwmUr743+1+/cfccTMou//B1mUue/ZpOyMNtsRtsVxJlLlr3hdv44nXbTJFsyNx+Ns37q75ySl1TNvm8iorz1SU7DV7xRZTVnzs2fllKWU/Uv6qGAnGn29KgulGySXlmqjlEn+eWpLyZVpymbFSp5FZvAw6MBAPcbm11cyYtev0qJIznZtxuaRdie19NfzYS8jCTN5O9bDexi2wK9/QCSEkJ7BDJ4SQnMAOnRBCckL3aOieRIRDrZPXaOhKN9eaOQAUJ6LWNrNm0JSVtM536Iwpq5w9G5s1PW3KUhEOD/7Wubb+M7GO/mPjtt1Kw0ciZIAU3bcGte5dGo1ou8yya926eWqqd3IauHJTcy5sZ39F6Ztut5F9p+LKc1b3rEyo7xbBHrNveNisz169Le7Xb93U+tS3kvCyc28T9c4k7v1Jfwvoa+BR1NemG/X0BkIzGN181n0bUbp5cBq6DA3Fzdw3jbPnDaBeSuPaHdaWTY9FO0+us/YZWhXvD5mYMmUmBkebs2TxDZ0QQnICO3RCCMkJnZVcEokqTERFn4wiNVNURUMszNgdj78xug0evdG6CW78Tixb86ydCahdA2XADef8TLBtF80vnrnU1r/hx3Hb4qFjpiyMjsQVHwmykn2d6qYbh+aeeqMmrl9jd1PuZcMvT9jd9j638PEBiIma6Gb7TdlhdOlwlG6mL7D1z24YjUc5cNC1u7LwMgA7lTchUxS70Hap2aANJapIRE1UszqnrrzAFE2eU0ImCZmjdNbWMfzw0wvWBwDTN14RD+l7Su2eOm7vuY4kI6/CN3RCCMkJ7NAJISQnsEMnhJCc0D1uiw3ITkZHTmYssnrZidfH5aEnh0zZ2u/viSuDXieP2pr4CIpOLzt0XdRXw4B1tVr3xGk0g/ZGrImaKBnLC613gmYTONdsmq3D9p2J3yqKR06aMvMVoybTkLr9/bcQz8lou3DhWts2lW1KGnE/tI2x66nkz73wPSSDGvdWve4jZV6wYX55fLPVzGu+qyl0yIu+SVvf8E+ft9sqN2QZHbVlPfi624NNJoQQshCLdugi8kUROSIiT6q/rRORH4jI3urvtaljkO6Dds0vtO3KpZ7x4X0APgfgS+pvdwLYGUK4W0TurK5/
vPXNW5zapLNq2blBzW6OrmgjL9qZojrQfc2sPe2i5Ie7rv7JdXr47aIfavczX4eOvObd1HSdbigedEKN1DC9lvvQabvWSAd1SjBHrawyePDI/HLFRd6TZJLm7GS+ISEHedlLJ8oonXuOLXxVuaf6maL10rjEch86bdsm8Nd8YtNwxpZWDim6YIul16IeM/L4K3a/GTfTWz/3a6zkoiXLmmfLPJMNRE1MJQZvAYveYSGE/wZw3P35FgD3V5fvB/CuFreLLDO0a36hbVcuzWro54YQDgJA9ffGrA1F5HYR2SUiu2amxrM2I91BU3adLp/N2ox0D3XZ1th1ls9rr7HsH0VDCPeEEHaEEHaUBkYW34H0BNqu/cXsoTHpLYxd+/i89hrNui0eFpFNIYSDIrIJwJFF91iMhMuddx9KJTzWmvrsmHVNlNS8ea1p91sXKZPs2UeLcxHiJrbHkAJhyiWk1W6UPluJ0uFq9LpUsueES2MTiaGXwa7pbw5N4ZICB5/pJqN+cVO7jZ65SLTIoCJu9k1Ym88OxeOGYevyKindvL3uh623bTOkwgI4G6zaFUNwDF5kBxR9h1+LK5M2TIN2RVz0btPf2fqzwwnU9EGJ7y/Grgkbh5pvOpmb1k2zh3gQwK3V5VsBPLD0ppAugHbNL7TtCqAet8WvAPgRgCtE5ICI3AbgbgA3iMheADdU10kPQbvmF9p25bKo5BJCeG9G0dta3Ja6Sc4UVZEKS6fsR53KeJzFWbG5Ys1wq0by0IlsfRB+P9ybif8jBw7ZyysqwmIYc7PSlKticHJMzdCsBXSjXQ0+8l6ziaeTdejkINkJNQAgqHWfzFhCyjUy2+V0uei4bRu5X80MYHtdw3h8tgpP7LNlWj7zUprG29HLdTNxvbzGSrRmO//qW1xe98Nm4UxRQgjJCezQCSEkJ7BDJ4SQnNDZaIt62nqNGL7gYu0hGsiOMvZUPN3T19osI/tXb55fvvC7Nipi38tK+3aZbMIZq9PL5Nb55ZlVtv7nPnDZ/PKGx6yWN/K8TUydRY3nZeeSo3QOn/lH6+0pN7GUTu40e6+11t+2FWIQ78ZXUe6hjWQsqru+xLun08kr21TSaJ9ZyW1bUN/cpkdcd6iaOvCaPVDhyIm40myEzWWAb+iEEJIT2KETQkhO6J6xQiOY2WWpMju82/yfR+eXn1tjo+LNbo8zAfdtt25QI49snV8unbHHnF7tQ+/FBoV+u+2aZ2LZqidftc0esdEfNTVD2IwycVNDe27wn0w0nJgNqofj3t1R413YlMtacGWouKH5cAxvUB5wCaW1Z+JxK9clbaASqcxsttFsD10X6yvafObY/JCNINjTpOyl8IlDRLn9nt1+rimbGov2qUkin8q34l5vi9Nx47Fd7prXKx0lkn9LwUVjbUFGGr6hE0JITmCHTgghOYEdOiGE5ISe0NDFux6Zqf8JLcsnBZ6ILoeXfOmAKXttx6b55RPbrEY6vTouzw7aY5a99K0ad+nXbHaU/v2HY7OH69fMzVrNFPjumXbc9Xi9NqHLFwatfQprVfLvmmiYcX1ym9VzCzOxjskNNt7ExLr4PnX6YnvIsaviN5afXv1vpuydD3V5bopm3Rb9FP6N8cE7ucNe13J/vOZa6waAvol4zQePWDfjyY02GmZFJfiuiU6q7DpzwXpTVNp/KK6kzrcR9O3Y5Ks239AJISQnsEMnhJCcwA6dEEJyQk9o6KkQuSm/Uj9dWGce8iFQxx6JfqZjD1vtO0yqLEQuXO70m19v1vf/Xqyj/4DP05sg4T9vaCAEq742TWQvWn6a1R69T7LyC5cRmw6vMhpDopZHrIZtMkg5ZoZd9ppEBnh9bcvDPpRrXNd6LQCMvhTT1Z/zmL2vKt+O5/HOQ7dktqXnUbp5ZbPVqY9eHTV0/x1t7Z44b6T0yglTVn456tviQlEX3/pGu+2AzhRm69A6/fErbWjddZWo6ZdeaeA5X2b4hk4IITmBHTohhOSE
rpVctFyQmvpeu6PO/OPcFkvxdIO4qd5qGO2Tvpo1l/R3fJMbmpfUdPI+nyTa1WkKm8tsY7IZdXCuv5YhpBXZhZwL28xl0a10ZrW95hU1NPZm1W2pMfl0HMdr98K5je2qPj8vXw0ciZE7i6+etPtNqKieXirSkqA736KWClKRBrsESUiG5n5w53L66mjXU1vs9SnMxP1GDlv79O07OL8cJmxshMLqVfPLZ3/tUlM2PZqd3Wh0v42cOr0uuq6evsC27cg1I7EtE1bmG9sfJdtKyZ7v2Q3xOOv/r/VSTfffKYQQQuqCHTohhOQEduiEEJITukZDT2biSZYl9FqfvUavFxP/y1xYS619itM6T2y3dZSORn1XhxqoqT/RtuDKUvptina6Kjalmyf2OXXtFrOuQy74qd6Dx6L7X9/4jCkrHj2lDuLC556Nrm9h/KwpK79pu10fjHYvTrnsNfteisfx3z9SGek1XZQ5vi7qi3oLwF6TE9ecZ8rObInPYZ+VsLH+yfj9of+5I/aYOqPURuvueOYNG+aXZ4btc14atw0ffTRq8ZXDNqT1kHKBHTyw0ZSd3j42v3z8CmvjY2+K69e8YZ8p+/olO+eXb77hD9Bq+IZOCCE5gR06IYTkhK6RXDzGDcrPFE3JLKnIbjrannch1Nv67DV6eDdko/BNn2O33fDjONzSM0wBQIbUbDMn+RgXy5qZiEpy8UN6yVjuVup0QZ1cY6/P0Il4nUcfPWTKwsnX4oqTcaw6l30/yMVW4vHuZtrFsbTHRurU2Y6k4B4pyXaHTboj9poEk3juTlwbZ1WevNyec7/y8tz4M5sove+lmGGssna1KZu48Py4vN65g6p7YM0em0GqsM/arjweZR0/qxQqo5UcsnLM2LHY8NIZe+8cHIjtefTA+abs+s9/YH55LewM11a8XvMNnRBCcsKiHbqIbBGRH4rIbhH5hYh8pPr3dSLyAxHZW/29drFjke6Bds0ntOvKpp439FkAHwshvA7AdQA+JCJXArgTwM4QwjYAO6vrpHegXfMJ7bqCWVRDDyEcBHCwunxaRHYDOB/ALQDeWt3sfgD/BeDjy9JKPw07pS8qmbwmmp7WxNQyAKOph2nr+oap6H44cdWFtmll25b1P1eioHdZ06EHSi5zvFoPLiqfWXf/gnWEuIZcGlts15ZP/XfoaeDhtNVatW5eo5P77yGKcEVMEzQzZjPZ+O80pWdeVvVbXRYlFYrA21yvp9xoGwj3kKIrnlfH4V+P13LgqCtUl/nUJXYKvWyNz9rsgL0+BRVxVS8DwJqnoqtq5fGnTVnZZakybsj++4e2ScKuw09bl8pLns0O6YCi+q62DN9JGtLQRWQrgKsB/ATAudWb55c30caMfW4XkV0ismtmanyhTUiHWapdp8sTC21COsyS7TrL57XXqLtDF5FVAL4B4I4QwqnFtv8lIYR7Qgg7Qgg7SgMji+9A2kor7NpfHFp8B9JWWmLXPj6vvUZdbosiUsLczfHlEMI3q38+LCKbQggHRWQTgCPZR2gcI6v4aaTGVc+7gmUf0yS1cLKKTlwRpuwMzzAdo6eNn2svmY/aKAdiImhZZR+IMBgTLIQBFzGwpBMh2JOoFLPdFmvcGBuglXatW2ZJyQzKPMUZe7zpVSpxyOXWTUzLI8VTzlX0bFyfPW+NKauoxBh9p63NC/tfsduqWaXioybq9YKLsGnO192cdc4cbpROPK+uAW49LpYHnFupkhOnR+1+RWWSkYP2eR08FEcP8qJ1Yy2fVLKnl+B825Qk4u2KUn92mYqkGry7o7ZzjZtx9ozwVlCPl4sAuBfA7hDCZ1TRgwBurS7fCuCBlreOLBu0az6hXVc29byhvxnAHwF4QkQerf7tEwDuBvB1EbkNwIsA3r08TSTLBO2aT2jXFUw9Xi4PI3sO4tta2xzSLmjXfEK7rmy6dup/KimvnpYtzmVJykrbmnEuazpJtNczC9l6ZmFbdG979Xp7zIEjTjMd
jdlSwrANE1AZjppcecDuV+mP7Sn3Z4cFCD54X2rqf7dPH09MhV+/83lTNH15zGwzscle1+Kkyjy0yn6bKJSjPUrHbETF4nMxY0zluMs05HRR6U/pqToap7uvjFtc4p5L0e12BJLfRq74QnTzPH3pqCmrlOK2g0etTj60W2Ulcq6qFfWdK3jXVJVVTHzWMv/9Q9kVA851tV/dSyXv0pjIKKXvgTbbjlP/CSEkJ7BDJ4SQnNA1kkvNLEe1XnEzJ00IvQH/Pyn7lAoJNzHRcowehgGY3BQjvZXGrHtb8UXrmlhZHWe7VQada+JQbFvZRfPTMlLNTFE1aq91W1TLHRyZZ7lPNjRrNDEzr/+ZOPwuVdwMYB0N0w2/g54RHLIjMZrhNWoTmZght3dN1MnA/X4pN80WuC0uxW11SfgZy5VEkvWJ6Pa7+nEbtdAkTnezt00SCyd56Oq95GJUyJQdASuXlVzCd+2a6BK+G5nFy2wJuxpXxWV4neYbOiGE5AR26IQQkhPYoRNCSE7oGg29RmvyiZoV2tVJuyjN7add/Jz7oZpeL8NWJy9oF0fnCvnCO+K2b7nYRm97/l+uMOuzYzGuic96o9vt25ZyTTTn5LXLVBiEbiTVxjrLxGvoSs8MrkzU9G2ERGZjSbsUGl04pZOn9NSaOhMaeq+hv+W490Rz7SoJTdm7imrduuxsp7NN+fvBtCsxLd/X6cpCoizl5pyc3r/Mr9B8QyeEkJzADp0QQnJC90guHpMYOSG/OBc/M3J2wyQ9A9MHxUdFuS+5oquu2zu/fO+FD5uy3yxuN+uzIyqJRU0yiuZkFTNrtsdH5gY/HNVuhTWRCXWZj9KoZgbWm0B8sbY0Ul7n8LvhOnuVlEtjKuKkSxYT9Lbedmo95Rpbm1S9AXukZnwmjrPcrokp+IZOCCE5gR06IYTkBHbohBCSE7pXQ09QozFrlH5V8TKXktpCOVsD8zrs0btjtMW3yO12YxugzbbNS8TGTc2XpfbTZb2lu3oNMxkKQJ/bYplmNMXEfq2iFVp4i2zXsen+LcC78Zk17/KZsqVODO4jkKZIXbtmdfKaOhpoT4vhGzohhOQEduiEEJITekNySQ19CtkR9FBMySqJ+lrlG5g4TLKKHpNVGkHLBXXLL42UJStvkRzTBvv0sqxSQ+K1MaQytCQmgKIRmaUV9Mirb480kxBCyGKwQyeEkJzADp0QQnKChOVy81qoMpFXAbwA4BwAR9tWcZqV2JaLQggbWnUw2nVRaNfWsVLbUpdt29qhz1cqsiuEsKPtFS8A29I6uqn9bEvr6Kb2sy1pKLkQQkhOYIdOCCE5oVMd+j0dqnch2JbW0U3tZ1taRze1n21J0BENnRBCSOuh5EIIITmBHTohhOSEtnboInKTiDwtIs+KyJ3trLta/xdF5IiIPKn+tk5EfiAie6u/17ahHVtE5IcisltEfiEiH+lUW1oB7Wrakhvb0q6mLT1h17Z16CJSBPB5AO8AcCWA94rIle2qv8p9AG5yf7sTwM4QwjYAO6vry80sgI+FEF4H4DoAH6pei060ZUnQrjXkwra0aw29YdcQQlt+AFwP4Htq/S4Ad7WrflXvVgBPqvWnAWyqLm8C8HQH2vQAgBu6oS20K21Lu/auXdspuZwP4CW1fqD6t05zbgjhIABUf29sZ+UishXA1QB+0um2NAntmkGP25Z2zaCb7drODn2hAM8r2mdSRFYB+AaAO0IIpzrdniahXRcgB7alXReg2+3azg79AIAtav0CAK+0sf4sDovIJgCo/j7SjkpFpIS5G+PLIYRvdrItS4R2deTEtrSroxfs2s4O/REA20TkYhHpB/AeAA+2sf4sHgRwa3X5VsxpY8uKiAiAewHsDiF8ppNtaQG0qyJHtqVdFT1j1zZ/SLgZwDMA9gH4ZAc+ZHwFwEEAM5h7A7kNwHrMfZ3eW/29rg3t+A3MDV8fB/Bo9efmTrSFdqVtadf82JVT/wkhJCdwpighhOQE
duiEEJIT2KETQkhOYIdOCCE5gR06IYTkBHbohBCSE9ihE0JITvh/iWvRzNmTR8sAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Standardize train_data per pixel (zero mean, unit variance over the training set).\n",
    "# The 1e-7 keeps pixels whose std is 0 from causing a division by zero.\n",
    "mean = np.mean(train_data, axis=0)\n",
    "std = np.std(train_data, axis=0)\n",
    "train_data = (train_data - mean) / (std + 1e-7)\n",
    "# NOTE: only train_data is transformed here. val_data / test_data must be standardized with this\n",
    "# same `mean` and `std` before being fed to a model trained on train_data.\n",
    "\n",
    "# Look at the standardized images\n",
    "plt.figure()\n",
    "for pos, idx in enumerate((26, 16, 5), start=1):\n",
    "    plt.subplot(1, 3, pos)\n",
    "    plt.imshow(train_data[idx].reshape((28, 28)))\n",
    "    plt.title('label: %s' % train_labels[0][idx])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def CrossEntropyLoss(P, y, eps=1e-12):\n",
    "    '''\n",
    "    Compute the mean multi-class cross-entropy loss and its gradient w.r.t. P.\n",
    "    \n",
    "    Input:\n",
    "    - P: predicted class probabilities. N x C\n",
    "    - y: integer ground-truth labels, shape N or N x 1.\n",
    "    - eps: lower bound applied to the true-class probability, so a probability that\n",
    "      underflowed to 0 yields a large finite loss instead of inf / a division by zero.\n",
    "    Return:\n",
    "    - Loss: scalar, the mean of -log P[i, y[i]] over the N samples.\n",
    "    - dL: the gradient of Loss w.r.t. P. N x C (non-zero only at the true-class entries).\n",
    "    '''\n",
    "    N, C = P.shape\n",
    "    rows = np.arange(N)\n",
    "    # Flatten so that both N and N x 1 label arrays index the same way.\n",
    "    labels = np.asarray(y).reshape(-1)\n",
    "    # Probability assigned to each sample's true class, kept away from 0.\n",
    "    p_true = np.maximum(P[rows, labels], eps)\n",
    "    Loss = -np.sum(np.log(p_true)) / N\n",
    "    dL = np.zeros((N, C))\n",
    "    dL[rows, labels] = -1.0 / (p_true * N)\n",
    "    return Loss, dL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Affine_forward(X, w, b):\n",
    "    '''\n",
    "    Forward pass of a fully connected (affine) layer: out = X w + b.\n",
    "    \n",
    "    Input:\n",
    "    - X: input data. N x D\n",
    "    - w: weight matrix. D x C\n",
    "    - b: bias vector of length C, added to every row of X w.\n",
    "    \n",
    "    Return:\n",
    "    - out: the affine scores. N x C\n",
    "    - cache: the tuple (X, w, b), kept for Affine_backward.\n",
    "    '''\n",
    "    scores = X @ w\n",
    "    return scores + b, (X, w, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Affine_backward(dout, cache):\n",
    "    '''\n",
    "    Backward pass of the affine layer out = X w + b.\n",
    "    \n",
    "    Input:\n",
    "    - dout: the upstream gradient. N x C\n",
    "    - cache: the tuple (X, w, b) returned by Affine_forward.\n",
    "    Return:\n",
    "    - dw: the gradient w.r.t. w. D x C\n",
    "    - db: the gradient w.r.t. b. length C\n",
    "    - dX: the gradient w.r.t. X. N x D\n",
    "    \n",
    "    NOTE(review): every gradient here is additionally divided by the batch size N.\n",
    "    CrossEntropyLoss already puts a 1/N factor into its gradient, so the resulting\n",
    "    update is scaled by 1/N**2 relative to the true gradient of the mean loss.\n",
    "    Presumably the learning rate used for training was tuned with this scaling;\n",
    "    removing the division would require re-tuning lr.\n",
    "    '''\n",
    "    X, w, _ = cache\n",
    "    batch_size = X.shape[0]\n",
    "    grad_b = dout.sum(axis=0) / batch_size\n",
    "    grad_w = X.T.dot(dout) / batch_size\n",
    "    grad_X = dout.dot(w.T) / batch_size\n",
    "    return grad_w, grad_b, grad_X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Softmax_forward(X):\n",
    "    '''\n",
    "    Row-wise softmax.\n",
    "    \n",
    "    Input:\n",
    "    - X: the scores. N x C\n",
    "    \n",
    "    Return:\n",
    "    - prob: the class probabilities; each row sums to 1. N x C\n",
    "    - cache: the probabilities themselves, which is all Softmax_backward needs.\n",
    "    '''\n",
    "    # Subtracting each row's max leaves the softmax unchanged but keeps exp() from overflowing.\n",
    "    shifted = X - np.max(X, axis=1, keepdims=True)\n",
    "    # Exponentiate once and reuse the result for both numerator and denominator.\n",
    "    exp_scores = np.exp(shifted)\n",
    "    prop = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)\n",
    "    cache = prop\n",
    "    return prop, cache"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Softmax_backward(dout, cache):\n",
    "    '''\n",
    "    Back-propagate through the row-wise softmax.\n",
    "    \n",
    "    Input:\n",
    "    - dout: the gradient w.r.t. the softmax output. N x C\n",
    "    - cache: the softmax probabilities p returned by Softmax_forward. N x C\n",
    "    \n",
    "    Return:\n",
    "    - dscore: the gradient w.r.t. the softmax input scores. N x C\n",
    "    '''\n",
    "    p = cache  # N x C\n",
    "    # The Jacobian of row k is diag(p_k) - outer(p_k, p_k), and different rows do not interact, so\n",
    "    #   dscore[k, l] = sum_j dout[k, j] * (p[k, j] * (j == l) - p[k, j] * p[k, l])\n",
    "    #                = p[k, l] * (dout[k, l] - sum_j dout[k, j] * p[k, j]).\n",
    "    # This avoids materialising the full N x C x N x C Jacobian (O(N^2 C^2) memory).\n",
    "    weighted = np.sum(dout * p, axis=1, keepdims=True)  # N x 1\n",
    "    dscore = p * (dout - weighted)\n",
    "    return dscore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(object):\n",
    "    def __init__(self, input_dim, num_classes, lr=1e-3, batch=50, weight_scale=1e-2, weight_decay=1.0, reg=0.0):\n",
    "        '''\n",
    "        Initialize a softmax (multinomial logistic) regression model.\n",
    "        \n",
    "        Input:\n",
    "        - input_dim: input dimension D.\n",
    "        - num_classes: number of classes C.\n",
    "        - lr: learning rate used by the SGD update.\n",
    "        - batch: mini-batch size used for SGD.\n",
    "        - weight_scale: the initial weights are weight_scale * standard-normal samples.\n",
    "        - weight_decay: stored on the instance; not used by train() yet.\n",
    "        - reg: regularization strength; stored on the instance, not used by train() yet.\n",
    "        '''\n",
    "        self.w = weight_scale * np.random.randn(input_dim, num_classes)\n",
    "        self.b = np.zeros(num_classes)\n",
    "        self.lr = lr\n",
    "        self.batch = batch\n",
    "        self.weight_decay = weight_decay\n",
    "        self.reg = reg\n",
    "        # Per-iteration mini-batch loss and accuracy, appended by train().\n",
    "        self.Loss = []\n",
    "        self.Acc = []\n",
    "    \n",
    "    def train(self, data, labels, target_acc=0.97, max_iter=None):\n",
    "        '''\n",
    "        Fit w and b with mini-batch SGD on (data, labels).\n",
    "        \n",
    "        Training stops as soon as the accuracy on the current mini-batch reaches\n",
    "        target_acc, or after max_iter iterations when max_iter is given.\n",
    "        \n",
    "        Input:\n",
    "        - data: the input data, one sample per row. M x D\n",
    "        - labels: integer labels of data, either a flat array or a 1 x M row vector\n",
    "          (as loaded from the .mat file).\n",
    "        - target_acc: mini-batch accuracy at which training stops.\n",
    "        - max_iter: optional cap on the number of SGD iterations (None means no cap;\n",
    "          a linear model may never reach target_acc, so setting a cap is advisable).\n",
    "        '''\n",
    "        # Accept both 1 x M row vectors and flat label arrays.\n",
    "        labels = np.asarray(labels).reshape(-1)\n",
    "        acc = 0\n",
    "        iteration = 0\n",
    "        while acc < target_acc and (max_iter is None or iteration < max_iter):\n",
    "            index = random.sample(range(len(data)), self.batch)\n",
    "            # Use the arguments, not the notebook-level train_data / train_labels globals.\n",
    "            mini_data = data[index]\n",
    "            mini_labels = labels[index]\n",
    "            # forward pass\n",
    "            scores, cache = Affine_forward(mini_data, self.w, self.b)\n",
    "            prop, cache2 = Softmax_forward(scores)\n",
    "            Loss, dL = CrossEntropyLoss(prop, mini_labels)\n",
    "            \n",
    "            # Prediction and accuracy on this mini-batch\n",
    "            pred_y = np.argmax(prop, axis=1)\n",
    "            acc = np.sum(pred_y == mini_labels) / len(mini_labels)\n",
    "            self.Loss.append(Loss)\n",
    "            self.Acc.append(acc)\n",
    "            \n",
    "            if np.mod(iteration, 10) == 0:\n",
    "                print('iteration:\\t%s\\t\\t\\tLoss:\\t%s'% (iteration, Loss))\n",
    "                print('Accuracy:\\t%s' % acc)\n",
    "            \n",
    "            # backward pass\n",
    "            dscores = Softmax_backward(dL, cache2)\n",
    "            dw, db, _ = Affine_backward(dscores, cache)\n",
    "            \n",
    "            # SGD update\n",
    "            iteration = iteration + 1\n",
    "            self.w = self.w - self.lr * dw\n",
    "            self.b = self.b - self.lr * db\n",
    "        \n",
    "    def test(self, data):\n",
    "        '''\n",
    "        Predict a class index for each row of data (argmax of the softmax probabilities).\n",
    "        \n",
    "        data should be preprocessed the same way as the data passed to train().\n",
    "        Returns an integer array with one prediction per row.\n",
    "        '''\n",
    "        scores, cache = Affine_forward(data, self.w, self.b)\n",
    "        prop, cache2 = Softmax_forward(scores)\n",
    "        \n",
    "        pred_y = np.argmax(prop, axis=1)\n",
    "        return pred_y\n",
    "    \n",
    "    def history(self):\n",
    "        '''\n",
    "        Return the per-iteration training log as the pair of lists (Loss, Acc).\n",
    "        '''\n",
    "        return self.Loss, self.Acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iteration:\t0\t\t\tLoss:\t2.6307473565334503\n",
      "Accuracy:\t0.113\n",
      "iteration:\t10\t\t\tLoss:\t0.30232980314862423\n",
      "Accuracy:\t0.914\n",
      "iteration:\t20\t\t\tLoss:\t0.30977954237103195\n",
      "Accuracy:\t0.914\n",
      "iteration:\t30\t\t\tLoss:\t0.22483821893055797\n",
      "Accuracy:\t0.934\n",
      "iteration:\t40\t\t\tLoss:\t0.18784810738704721\n",
      "Accuracy:\t0.946\n",
      "iteration:\t50\t\t\tLoss:\t0.2073503098041697\n",
      "Accuracy:\t0.943\n",
      "iteration:\t60\t\t\tLoss:\t0.20878012554384653\n",
      "Accuracy:\t0.939\n",
      "iteration:\t70\t\t\tLoss:\t0.20084784410746312\n",
      "Accuracy:\t0.947\n",
      "iteration:\t80\t\t\tLoss:\t0.19398323854928165\n",
      "Accuracy:\t0.948\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEICAYAAACzliQjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8HWd97/HP72w6R8vRbluWZEve19hOHJN9dSBJacIWCC00YWmAAm1oaRvoLVzCbWlLL0tfcIFAAqHQhBAChJCQ2GTfHMtxFlvyKlu29l3naDn7c/+YkaLVVmxLR5nze79eelmznXnOePSdZ56ZeUaMMSillMoMrnQXQCml1OzR0FdKqQyioa+UUhlEQ18ppTKIhr5SSmUQDX2llMogGvpKKZVBNPRVRhORoyKyNd3lUGq2aOgrpVQG0dBXahIi8pcickhEukXkQRFZaI8XEfmmiLSLSJ+IvCYi6+xp14pIrYiERaRJRD6f3m+h1EQa+kqNIyJXAF8D3g+UAQ3AvfbktwOXACuAAuADQJc97U7gE8aYPGAd8PgsFlupafGkuwBKzUF/DtxljHkZQES+APSISBUQB/KAVcBLxpi6UcvFgTUi8qoxpgfomdVSKzUNWtNXaqKFWLV7AIwx/Vi1+XJjzOPAd4DvAm0icoeIBO1Z3wtcCzSIyFMicv4sl1upk9LQV2qiZmDx8ICI5ADFQBOAMea/jDHnAGuxmnn+3h6/0xhzPTAP+A1w3yyXW6mT0tBXCrwi4h/+wQrrj4jIRhHJAv4V2GGMOSoi54rI20TECwwAESApIj4R+XMRyTfGxIEQkEzbN1JqChr6SsHDwNCon4uBfwZ+BbQAS4Eb7XmDwA+x2usbsJp9/tOe9mHgqIiEgE8CH5ql8is1baIvUVFKqcyhNX2llMogGvpKKZVBNPSVUiqDaOgrpVQGmXNP5JaUlJiqqqp0F0Mppd5Sdu3a1WmMKT3ZfHMu9KuqqqipqUl3MZRS6i1FRBpOPpc27yilVEbR0FdKqQzimNDvCEc571//yP27GtNdFKWUmrMcE/o5WW5aQxE6+6PpLopSSs1Zjgn9gNeNxyX0DcXTXRSllJqzHBP6IkJ+wEtIQ18ppabkmNAHyA94taavlFIn4KjQD2roK6XUCTku9LV5Rymlpuao0M8PeAlFEukuhlJKzVkOC32PNu8opdQJOCz0rTZ9fRuYUkpNzlGhH/R7SaYMAzF9H7VSSk3GUaGfH/AC6MVcpZSagiNDX9v1lVJqco4K/aCGvlJKnZCjQl+bd5RS6sQcGfpa01dKqck5KvS1eUcppU7MUaGfl+VBRJt3lFJqKo4KfZdLyMvyaFcMSik1BUeFPkB+tva0qZRSU5nx0BeRShF5QkTqRGSviPzNTK5P+9RXSqmpeWZhHQng74wxL4tIHrBLRLYZY2pnYmVBv4a+UkpNZcZr+saYFmPMy/bvYaAOKJ+p9ekrE5VSamqz2qYvIlXAJmDHTK1Dm3eUUmpqsxb6IpIL/Aq41RgTGjftFhGpEZGajo6O01qPvjJRKaWmNiuhLyJerMD/uTHmgfHTjTF3GGM2G2M2l5aWnta68gNeookUkbh2r6yUUuPNxt07AtwJ1BljvjHT6xt+KjcU0dq+UkqNNxs1/QuBDwNXiMgr9s+1M7Uy7XRNKaWmNuO3bBpjngVkptczLOi3vpK26yul1ETOeyJ3pKavXTEopdR4jg19rekrpdREGvpKKZVBHBf6Qb2Qq5RSU3Jc6HvdLrJ9bq3pK6XUJBwX+qBdMSil1FQcGfra06ZSSk3OkaGfH/DqE7lKKTUJR4a+1ema3qevlFLjOTL0tU99pZSanCNDPxjwaJu+UkpNwpGhnx/w0h9NkEim0l0UpZSaUxwb+gDhiLbrK6XUaI4OfW3iUUqpsRwZ+kG/vkhFKaUm48jQz8/Wmr5SSk3GmaGvzTtKKTUp
R4b+cPOOhr5SSo3lyNDXt2cppdTkHBn6fq8Ln9ulNX2llBrHkaEvInb/Oxr6Sik1miNDH6yuGPSWTaWUGsuxoa+driml1ESODn1t3lFKqbEcG/rz8rJo7h3CGJPuoiil1Jzh2NBfUxaksz9Gezia7qIopdSc4djQX1ueD8De5r40l0QppeaOGQ99EblLRNpFZM9Mr2u01WVBRGBvU2g2V6uUUnPabNT0fwJcPQvrGSM3y0NVcQ57mzX0lVJq2IyHvjHmaaB7ptczmTULg+zR5h2llBrh2DZ9gLULgzT2DNE3qLduKqUUzJHQF5FbRKRGRGo6OjrO2OeuXWhfzG3R2r5SSsEcCX1jzB3GmM3GmM2lpaVn7HPXLgwCUKvt+kopBcyR0J8pJblZzA9m6cVcpZSyzcYtm/cALwArRaRRRD420+scbe3CfL1XXymlbJ6ZXoEx5oMzvY4TWbswyFMHOojEk/i97nQWRSml0s7RzTtghX4yZdjXGk53UZRSKu0yIPStO3j2NGkTj1JKOT70KwoD5Ae8ejFXKaXIgNAXEdaUBanVi7lKKeX80AerXX9fa5hEMpXuoiilVFplRuiXB4kmUhxs7093UZRSKq0yIvTPX1KCxyX8YufxdBdFKaXSKiNCf0G+n3dvKufencfo7Nc3aSmlMldGhD7AJy9bSjSR4sfPHUl3UZRSKm0yJvSXluZyzboF/PSFBkIR7WpZKZWZMib0Af7qsmWEIwl+9mLDmPHGmDSVSCmlZldGhf668nwuXVHKXc8eYSCa4A97Wnjv955n9Zf+wDe3HSAST6a7iEopNaMyKvQB/uqypXT2xzj/a3/kkz97mfZwhIuWlfDtPx7kHd96mif3t6e7iEopNWNmvJfNuWZLdRHXrl9ARzjKRy+s5u1rF+B2Cc8d6uSff7uHm3+8kwuXFfPxi5Zw6YpSXC5Jd5GVUuqMkbnWnr1582ZTU1OTlnVHE0l++nwDdz57hNZQhKWlOXzy0qW875wKRDT8lVJzl4jsMsZsPtl8Gde8cyJZHjd/eckSnvnHy/n2jRsJ+Nz8/f2v8fG7a+g6xfv7dx/r4WuP1DEU0+sFSqn005r+CRhj+MnzR/naI/sI+r18/YazWL0gSGd/lM7+KO3hKO2hCG2hKPFkik9fvozKouyR5V9v7OPPfvgi4WiCLVVF3HnzZvL83jR+I6WUU023pq+hPw37WkP89T27OdA2ed89+QEvsUQKn8fFN96/gStXz+dgW5j3/+AFsn0e/vLiav7P7+tYuzDI3R/dQn7AS01DDz99oYF4IsVt16yiqiRn5POGm5mCAQ/v31ypTUtKqZPS0D/DIvEk9+9qRASKc7IozfPZL1734/e6OdY1yKd+vou9zSFuvqCKh19vAeCXnzyfxcU5bKtt49M/f5nFxdl43S5qW0IE/R6MgXgqxd+/YxU3X1DFUwfauf13tRztGgTg2vUL+Lf3nkXwFM4QmnqHiCVSVI86oCilnElDPw0i8SRf+d1e7nnpOAXZXu77xPmsmJ83Mv3Zg53c8t81VBZmc9MFVbxr00JCQwm++OvXeXxfOwuCflpDEZaU5vDlP13LvpYQ//HofsoLAnznzzZxVkXBpOs1xow5G2jqHeI7jx/kvppGkinDtesX8LmtK1g+qiynqz0U4cUj3QT9Hi5Zrnc5KZVuGvpp9MS+diqLslk2L3fCtEg8SZbHNSakjTH85pUmvv9kPe85u5yPXFiNz2NdY9/V0M1n/2c3zX0RVi3I47wlxZxbVUTPYIyao93sPNpDWyhCRWGAxcU5BANeHt3TCsCfvW0RQb+HO589wmA8yXUbFvLesyt425IisjxjXxIfT6Zo7YvQ3DtES1+Ezv4oPYMxugfiDMYSeFwuvG4hkTLsPtbD4Y6BkWVXzs/jk5ct4Z1nLcTjEroGYrT2RSjL91OcmzXt7ZZKGTr6o8zLyzqtJq3xB0GlMoGGvoP0DMT4+Y4GXqzvpqahm0jcehlM
aV4WW6qKqCgK0NgzREPXAK19Ebauns9nr1xOeUEAgO6BGD946jA/faGBoXiS3CwPl64oJRjwcqx7gKOdg7T0DZEatyt4XEJhjo9sn5tE0hBPpkgZWFce5IKlxZy3pJj6jgG+9+Rh9reFCfo9ROIpYvbLakTgrPJ8Ll05j62r57G+PH9CGB9sC7Otro2aoz3UHO0mFEmwrjzIRy+s5p1nLcQl8NzhLn79ciOvHO+luiSH1WVBVpcFKS8MUJjtoyjbR+9QjO117WyvbWPn0W42Vhbw/nMr+ZP1ZeRkTXwcZVttG197pI6A183HLrLW5fO4GIgm+P1rLfzutWaMsbZxaV4W8/KyqCjMpqIwQGVhNvnZJ29u6x2MkZvlwePWm+TmkqbeIXY19HDlqnmT7htvVipleKy2lSyvm8tWlKatwqGh71CxRIralhBF2T4qiwJvageLxJM8d6iT7XVt/LGunVgyRVVxDlXF2Swqyqa8MMDCggBl+QFK87II+j3T+nxjDE/sb+fRPW0U5HhZmB9gXl4WB9r6efJAO68c78UY64zgxi2VXLdhITuPdnP38w28UN8FwLJ5uZxbVUhlUTYPvNzEofZ+SvOss4SOcJSg38OW6mKOdw9yuKOfxPgjlG3F/FzeVl3Mc4c7qe8YIMfn5orV89lSXcSWqiKyfW6+8rtatte1sWJ+LsbAwfZ+5gez2FJdzON1bQzEklSX5JAf8NIRjtLRHyWWGPvWtcJsL9UlOVSX5FKc68Mlgttl/f/sb+unriVERzhKSW4W7z27nBs2V7C0NJeGrkF2Hu3m1cZeovEUbpcgIpTmZXH+kmI2LSrA77XOwobPfDrCUUKROKGhOB39MY52DlDf0U9D1yBul1CU46M410eWx03PYIyegRihSIJsn5uCbC8FAR95fg9+r5uAz43X7WIolqA/mqA/mmTdwiA3bllEfuD07yyLxJPsbQ7hdYu1Pq+b3CwPwYAXt90EGIkn6QhHaQ9HCEcSDMWSDMaSBHxuls3Lpao4Z+RMd1g8meJ49yBHuwZo6YtYlYtEingyRU6Wh5JcH6W51gF6YUFg0jA/1B7me0/W89tXmkikDCW5WfztVSt4/+aKkQNzJJ5kIJqY9hnqi/VdfPWh2pF3cJ+9qIB/uHoV5y0pHrNNvG7XyPefKRr6as7oGYjxyJ5W7t15jNca33hXcXlBgA+dt5j3nVMxEvBghd0zhzr57xcacLvgXRvLuXzVvJEwjCaSHGrvpz0cpWcgRvdADJ/HxWUr5rGo2Lpl1hjDy8d6+MXO4zy5v4P28BvPWQS8bm7dupyPXlSNxyU8eaCDHz1Tz2uNfVy9dgEfOLeScxYXjhzwjDH0DsZp6h2isWeQY92DHO0a5EjHAEc6B+gdipFKQSKVwuNysWxeLqvLgiyfn8vLDT08vq+dRMqQH/DSN2T18Jrn95Cb5SFlDMkUdA9ESRnI8rhYuzBI71Ccxp6hCQcbAL/XRXVJLlXF2aSMoWcgTtdAlGgiRVGOj4JsH0G/h6FYkt6hOD2DMQaiVrgOn4ll+6wwzvK6ON49RI7PzQfOXcS7N5UTS6boHYzROxhnIJZgMJZkMJpgKJ4kmkgRjadIGsOyebmsL89n3cJ8jnQN8Mua4zz4ajPhSGLS/SAvy4PbLfQOnriXW7dLKC8IIMLIGWbXQIzkFAf6yeQHvCwI+nG7hJQxJFKGwx39ZHlc3HjuIi5ZUcL/e+IwNQ09LJ+Xy5LSHA609XO0awBjYElpDhcsLeb8JSVkZ7mJxq1tNxBLEI4kCEfi1LWEeXxfOwvz/fzD1asYiif51vYDtIWirC/PJ5ZI0dI3RCiSwCVQnJtFaW4WZfl+qktyWFKay9LSHDYuKpjQ3HoqNPTVnLS3uY8/7GllXXk+W1fPn/HaD1ihfbx7iJeOdnOse5APnFs50vQ1GzrCUX6z2zp7Oasyn3OrilhWmjvm4ncoEuel+m6eP9zFnqY+SvJ8VBRmU1kYoDTP
T37ASzDgoTjHamo6nQvn46957Gnq40fP1PPQay1TnkGBdbDxe934PW4MhrZQdML0a9aVcfW6BXhcwlDcqsH3RxL0DcXpG4qTSKWYn+dnfr6feXlZ5Pm9ZPvcZPvchCMJDrX3c6i9n2Pdg4gwci2pONdHdUku1SXZlBdkE/C68XmsaeFIgq6BKB3hGO3hCM291rWp1lAEYwwuEVwirFyQx1+cv3ikFm+M4dG9rXxr+0FiyRQr5+exckEefq+bHfVd7DjSzeAUD1W6BIpysrj5gsV8/OIlIxWSSDzJT184ymN72yjK8bEg38/8oJ9oPEl72Dpra+od4kjnAFH7gJ7n93D12gX86YaFXLC0+JSbAzX0lVJvSnPvEDUNPQT9HgqyfRQEvOT6PWT7rKAff6DpGYixp7mP15v6KM7xcc36slO6tXiuiidT1DaHSKTMyAEvx+chz94mp9N2n0oZmvuGqGsJ84c9rTy2t5Vw1Lqe9dBnLz6lz9TQV0qpt4hIPMlTBzoYjCV496aKU/qM6Yb+rPSyKSJXA98G3MCPjDH/NhvrVUqptwK/18071i6YlXXN+L1kIuIGvgtcA6wBPigia2Z6vUoppSaajRuItwCHjDH1xpgYcC9w/SysVyml1Diz0bxTDhwfNdwIvG2qmXft2tUpIg1TTZ+GEqDzNJZ3Gt0eY+n2mEi3yVhv1e2xeDozzUboT3aJe8zVYxG5BbjFHvwnY8wdp7wykZrpXMzIFLo9xtLtMZFuk7Gcvj1mI/QbgcpRwxVA8+gZ7JA/5aBXSik1PbPRpr8TWC4i1SLiA24EHpyF9SqllBpnxmv6xpiEiHwGeBTrls27jDF7Z3CVesYwlm6PsXR7TKTbZCxHb48593CWUkqpmaN9virHEJEnRaRHRKbfib9SGUZDXzmCiFQBF2PdGXbdLK53Vp5qV+pMcUzoi8jVIrJfRA6JyG3pLk86iEiliDwhInUisldE/sYeXyQi20TkoP1vYbrLOgP+AngR+Alw0/BIEQmIyDdEJCYicRF5VkRWicgOETkuIp0i0mv/frO9zJMi8vFRn3GziDw7atiIyKdF5CBw0B73bfszQiKyS0QuHjW/W0S+KCKHRSRsT68Uke+KyP8d/SVE5HcicuvMbCIQkQIRuV9E9tn7yfkZsn9MSUQ+Z/+97BGRe0TEb994ssPeJr+wb0JxBmPMW/4H6wLxYWAJ4ANeBdaku1xp2A5lwNn273nAAayuL/4DuM0efxvw7+ku6wx890PAXwHnAHFgvj3+u/a0XwMPARcA9wOfBsLAdvv3YmCjvcyTwMdHffbNwLOjhg2wDSgCAva4D9mf4QH+DmgF/Pa0vwdeB1ZiPbeywZ53C9btyy57vhJgcLjsM7Sd7h7+bvbfSkEm7B8n2B7lwJFR/4/32f/f9wE32uO+D3wq3WU9Y9853QU4Q/9x5wOPjhr+AvCFdJcr3T/Ab4GrgP1AmT2uDNif7rKd4e95kR30JfbwPuBzWGeyQ8AO4Ao79AXract/sg8EY/Yde/nphP4VJylTD7DB/n0/cP0U89UBV9m/fwZ4eAa3U9AOOBk33tH7x0m2yXCPAUX2Afsh4B32PuKx55mwj7yVf5zSvDNZVw/laSrLnGC3cW/CCrz5xpgWAPvfeekr2Yy4CXjMGDP86Pz/2ONKAD9WzXv4FVTFQC/W/nGYU99XRu9viMjf2c0lfSLSC+Tb6wfr4cTDU3zO3VhnCdj//vcplGW6lgAdwI9FZLeI/EhEcnD+/jElY0wT8J/AMaAF6AN2Ab3GmOFXgDkqT5xyEeqkXT1kEhHJBX4F3GqMCaXrRc2zQUQCwPsBt4i02qOzsJotPgAksZpxhtuphzfGcazmFZi4rwwA2aOGJ+vzdmQZu/3+H4Ergb3GmJSI9Ixb11JgzySf8zNgj4hsAFYDv5nyy54+D3A28FljzA4R+TZWc07Gsq9fXA9UY1UGfonVI/B4jskTp9T0T9rVQ6YQES9W4P/c
GPOAPbpNRMrs6WVAe7rKNwPehRXsa4CN9s9q4BmsmnME6wLvfVhNPD/DOiDcC2zFugDcIiLFIrLR/sxXgPeISLaILAM+dpIy5AEJrFq0R0S+hNWUMuxHwFdFZLlYzhKRYgBjTCPWU+v/DfzKGDN06pvipBqBRmPMDnv4fqyDgJP3j5PZChwxxnQYY+LAA1jXfQpG3ZnlqDxxSuhrVw+AWFX6O4E6Y8w3Rk16kDfuaLkJq63fKW4CfmyMOWaMaR3+Ab6D1etgGfADrFq3FwgAT2H19Hot8NdYt3q+gnWBFeCbQAxow2p++flJyvAo8AjWhfMGrAPN6Oafb2AddB4DQlj/R6Nf0ns3sJ6ZbdrB3i7HRWSlPepKoBZn7x8ncww4zz7AC29skyeA99nzOGqbOOaJXBG5FvgWb3T18C9pLtKsE5GLsGq4r/NGG/YXsdr17wMWYe3kNxhjutNSyDQRkcuAzxtj3ikiS7Bq+kXAbuBDxpjoiZaf4bJdgnUGUmWMSZ1s/tNc10asMw8fUA98BKvyl7H7h4h8BaspMIG1P3wcqw1/zuwjZ5JjQl+ptyK7Oe5e4FVjzO3pLo9yPqc07yj1liMiq7EuHpZhnaUqNeNOGvoicpeItIvIZHceYF+Y+i/7SdjXROTsUdNusp9oOygiN022vFKZyhhTZ4zJMcZcYIwJpbs8KjNMp6b/E+DqE0y/Blhu/9wCfA+sR/+BL2NdMNsCfDnTHu9WSqm55qT36RtjnrYf9JnK9cBPjXVx4EW7b48y4DJg2/AFIRHZhnXwuOdE6yspKTFVVSdanVJKqfF27drVaYwpPdl8Z+LhrKmehp32U7Iy6h25ixYtoqam5gwUSymlMoeINExnvjNxIXeqp2Gn/ZSsMeYOY8xmY8zm0tKTHqiUUkqdojMR+lM9DatPySql1DQd7RzgxfquGV/PmQj9B4G/sO/iOQ/oszttehR4u4gU2hdw326PUxksmkjyw6frqWvRm1XmEmMMdzx9mOu/8yw/feEo/dHESZcBSKYMB9vC7G3uo28oPrOFnCGtfRF6BmInnc8YQyp15p9reuV4L3/1811c/n+f5H/9Zg8z/ezUSdv0ReQerIuyJSLSiHVHjhfAGPN94GGsx9kPYfUF/hF7WreIfBWriwSA2zPpKT81USgS5xM/3cUL9V343C7+4eqVfPTCalwu53YI92bUd/Tz691NPH2gg2yfh9K8LErzslhUlM3qsiCryvII+r1nfL2ReJIvPvA6D+xuoizfz5d+u5ev/2E/79tcwfryfIJ+L8GAtd6OcJTO/ijNfUO83tjHa419Yw4QeVkeFhVnc87iQt5WXcyW6iKyvC46w1E6wtYDrWdVFBDwuU9aro5wlLqWEPtbw2R5XdY2WJBH3pvYBr2DMWpbQuxrCVOY4+WadWX4vda6w5E439h2gLufP4rbJVy+ch7vObuCy1dZTczRRIr+SIKdR7t56kAHzxzsJDQUZ+WCPFYvCLKuPMg71i1gXp5/ZH2plOHxfe28UN/FuvIgb6suZmFBgGTKcKi9n1eO91DfOUBoKEE4Eud4zxCvHu8lz+/hU5cu5eYLqpjpDhLn3BO5mzdvNnoh960tEk+y82g32T4PZ1Xk43W7aO2LcPOPX+JQez9fvm4tzxzo4LHaNi5YWszXb9hAecEbXdEYY9he1863th8gJ8vDP127mg2VBRPWMxRL0tQ7RGPPID63i/OWFE96AOnqj/LSkW52HOnmaNcAN5xTyTXrFoyZd09TH7UtIS5fOY/SvImv2DXGcKRzgFcbe9nTFKI1FBkJwFTKEAx4yfN7KM3N4vJV89i6ej45WR6MMextDvHr3U3sbe6jKMdHSW4WxTnWOqKJJLFEipeP9fDysV5cAucsLiRloLPfCsrBWHKkHEtLc/jYRUt47znlZHncE8p4oK2fbbWtPH+4C7dLCPqtchVk+0YOIiU5PoIBL/kBLyljuPUXr7D7WC9/e9UKPnvFMl45
3stPnj/K719rITFFzdbjElaV5bGxsoCNlYVk+9w09Vj/F4c7BtjV0MNQPDnpsl63sKGigHMWFzIUT9LUM0RT7xChUWcKkUSK7ilq39UlOVy4rJiLl5dywdJi3C6x1z1EQ9cA9Z0D1HcMcLijn5a+yJhlg34P7zm7gpUL8vjW9gO0h6N8cMsisr1ufvtq88iBabyCbC8XLSthXp6f/W0h6lrCdA/E8LiEq9bM58Yti2jri/CDpw9zuGMAt0tI2tuuvCBA31B85ODoc7sIBjwE/V7ys738yfoybtyyiNys07uvRkR2GWM2n3Q+Df3M8/SBDr65/QDJlGHr6vlctWY+1SU5vHC4i211bTy1vwOvWygvDFBRkM3y+blcsWoeS0pzp/zMUCTOtr1tbKtt4+mDHSNBFfC6OWdxIfUd/fQNxfn+h8/h4uWlGGO4r+Y4X/ldLZF4ks2Li9i6Zh4r5ufxvScPs+NIN9UlOYQjCTr7o9xwTgWfuHQpdS0hnjnYwXOHumjqHdsh5fJ5udxyyRKu31hOWyjCQ6+18PvXm9nTZDUl+b0uirJ9NPdFWFMW5PPvWEEknuInzx3lpaPWSajbJVy6opTrNy4kkTTUtYSoaw2xpyk00nzh97pYmB+gJDeLkjwfbpeLcCROaCjOse4hOvujZHlcXLKilGNdg+xvC+Nzu1hbHiQ0FKcjHCUUsQLA4xKyPC4qi7J596Zyrt9YzoL8N2qOxhjaQlaNt7YlxGO1bbx6vJeyfD+fuGQJ84J+Drf3U99pBe2x7kEA1i4M4vO4CA3FCUUS9A7GiCcn/1v3e1184/0buXZ92Zjx4Uiczv6Y/d0SGIx10MjNojDbh/sEZ2jxZIrXm/rYdbQHgJI8H6W5fqKJJC8d7WZHfTevN/WR43NTXphNeUGAgmzvyN0fHreLpaU5rCkLsrosSCSRtLZBc4jdx3p5ob6LwVgSERgfYblZHpaW5rCkNJdVC/JGzhDqOwe456VjPPJ6K7FkirULg/zLu9ez0a5QJJIpnjnUySvHevF5XGR5XPi9btaX57OuPH/M9zXGcLhjgPtqjnP/rsaRA9SasiCfuHQJV69bwKH2fl460k1NQw+F2V42VhaysbKAJSU5M3J2q6HvQP3RBDVawYncAAAQ1ElEQVRHu6k52kM48katKNfvYfPiIs6pKjzh6X9dS4ivPbKPpw90UFlkhdbuY70AIzWTHJ+bi5aX4HW7aLRrT539Vu1naWkOW1fPZ8X8PCoKA5QXBjjY3s8DLzfx2N5WookUC4J+tq6Zx5Wr5xOJJdlxpJsX67uIJ1N8+8ZNrCvPH1Om492D/LLmOI/VtrGvNQxAcY6PW7cu58Yti4jEk3zn8UPc9dyRkdAK+j1cuKyEtQuDVBRmU14YoLFnkDuePkJdS4i8LA9hu1a1sbKAq9bM57wlxay3/3AffLWJb247OBKQlUUBbjq/ii3VRTyyp5Vfv9xEa8iqIfo8LlbOz2PtwqBVq11UwPJ5eVMGXiplqGno4aHXmtlW28aCfD/vObuCPz2rjILsN16zGk+mcImcMDgnY4zh2UOdfHv7QWoaekbGL8z3s7osyJWr53Pl6nnMD/onLNdnH3A6+qOEIwlCdu3zomUlLJ+f96bKcSYkU+ZNf/9hsUSKXQ09vFjfhc/jsvbHggCLirIpzcs6YRNJz0CMutYQW6qK8LhP/7JmNJHkiX0dBP0ezl9aPOPNM1PR0E+TZMr6o6woDLB0XM04mTI8faCDJaU5LC7OGTNteCdeuSCPopw3wiEcifOrXY38+pVm9jT1jfyh5PnfOBXsjyRIpAwugdVlQaqKcyjJtU7nh+JJ6lrC1LWEaOmLkB/w8tkrlvHh8xeT5XHTHo7weF07hzv6uWBZCecvKR5p8xzW2DPI9to2tte182J914RT/oJsL9dtWMi7NpWzqbLglHf6492D7Gnq46LlJRPabes7+nnqQAcbKgvYUFEwaVgYY3jmYCe/2d3EygV5
XLu+jMqi7AnzgRW6D7/eQo7Pw+Wr5o35vGTK8MrxXoJ+D9UlOWckGM604WYjEau5I9vnlPchqVOloT/D4skUyZQhy+NCRBiKJbl/13F++MwRjnUP4nULt25dwScuWYLH7aKha4DP//JVdh7twe0Srt+4kE9fvozSvCzu2XGMu547Qlsoiktgc1URV62eT1PvEPfvaqQ/muCsinwuXVHK26qLOXtxwZg/8qFYkpeP9bDjSDe7Grpp6bPam8ORBG6XsLQ0h9VlQdaX53PDOZXkZ5/6xcBIPElLX4TGnkGaeoYoyvFx2cp5+DxzLxiVyiQa+mdY31Ccrz+6jwOt/TT1DtHSN0TKvHFRJhpPEY4m2FhZwEcvqubRva38/rUWNlTk8/a1C/jO44fwuIXbrlnFkY4BfrajgVgiRcDrZiCW5IKlxdy4ZREH28Jss5s6vG7hnWct5KYLqkbaHd+MSNxq8xx/wU8p5Twa+mdQ72CMD9/5EvtaQ2xaVEhFQYCKwgBZXjfhiHXrVTJleM/ZFZxbVTjSvPHQa83882/20DMY55IVpfz7e9dTlm/dpdLZH+XOZ4/Q1R/lw+dVsb5ibFt3U+8Qfo+L4tyJd5IopdR4GvpnSPdAjA/9aAeH2vv5/ofP5opV89/U8p39UfY2h7hkeUnaLvAopZxvuqGvV3/GGYol6bTvbugdinH772o50jnAD2/azKUr3ny/QCW5Wae0nFJKzYSMDP32cITQUJxFRTn4PC6MsW6zu2fHMX7/egvRxBuvKfV7Xdx187lcuKwkjSVWSqkzI+NCv6s/yp/817N0hKO4XUJlYQAR4UjnALlZHt53TgUbKgsI+q0n5paU5o55WEYppd7KMir0jTH8469eo28wzu3Xr6UjHKW+Y4BQJM6nLl3KOzeU6f3OSilHy6iE+5+XjrG9rp1/fuca/uL8qnQXRymlZl3GPFFzqL2frz5Uy8XLS/jIBVXpLo5SSqVFRoR+LJHi1l/sJuB18583bNCufJVSGSsjmne217WxpynEd/5s04SOqJRSKpNkRE1/e20bBdlerl67IN1FUUqptHJ86CeSKR7f384VK+fNyd4SlVJqNk0rBUXkahHZLyKHROS2SaZ/U0ResX8OiEjvqGnJUdMePJOFn45dDT30DsbZuubNdZ+glFJONJ135LqB7wJXAY3AThF50BhTOzyPMeZzo+b/LLBp1EcMGWM2nrkivznb69rwua23GCmlVKabTk1/C3DIGFNvjIkB9wLXn2D+DwL3nInCnS5jDNtq2zhvafFpv39SKaWcYDqhXw4cHzXcaI+bQEQWA9XA46NG+0WkRkReFJF3nXJJT8HhjgGOdg1y1ep5s7lapZSas6ZT/Z3spvap+mO+EbjfGJMcNW6RMaZZRJYAj4vI68aYw2NWIHILcAvAokWLplGk6dle1wag7flKKWWbTk2/EagcNVwBNE8x742Ma9oxxjTb/9YDTzK2vX94njuMMZuNMZtLS89c2/v22jbWlQdHXlyilFKZbjqhvxNYLiLVIuLDCvYJd+GIyEqgEHhh1LhCEcmyfy8BLgRqxy87E7r6o+w61sPW1VrLV0qpYSdt3jHGJETkM8CjgBu4yxizV0RuB2qMMcMHgA8C95qxr+JaDfxARFJYB5h/G33Xz0x6fF87xqChr5RSo0zrlhZjzMPAw+PGfWnc8P+eZLnngfWnUb5T9vzhLublZbF2YTAdq1dKqTnJsY+odg3EKCsI6HtplVJqFMeGfn8kTp7em6+UUmM4N/SjCX0gSymlxnFu6EcS5Po19JVSajTHhn44miBPQ18ppcZwZOgbY+iPJrRNXymlxnFk6A/EkhiDNu8opdQ4jgz9/kgCgNwsb5pLopRSc4szQz8aB7Smr5RS4zky9MN2TV8v5Cql1FiODP3+qB36eiFXKaXGcGToD9f0tXlHKaXGcmTov3EhV0NfKaVGc2Toh0ead/TuHaWUGs2RoT9c08/Jcqe5JEopNbc4
M/SjcbJ9bjxuR349pZQ6ZY5MRe1hUymlJufI0A9pD5tKKTWpaYW+iFwtIvtF5JCI3DbJ9JtFpENEXrF/Pj5q2k0ictD+uelMFn4q/RHtbE0ppSZz0mQUETfwXeAqoBHYKSIPTvKC818YYz4zbtki4MvAZsAAu+xle85I6afQH9WavlJKTWY6Nf0twCFjTL0xJgbcC1w/zc9/B7DNGNNtB/024OpTK+r0WTV9vV1TKaXGm07olwPHRw032uPGe6+IvCYi94tI5ZtZVkRuEZEaEanp6OiYZtGnpjV9pZSa3HRCXyYZZ8YN/w6oMsacBWwH7n4Ty2KMucMYs9kYs7m0tHQaRTqxcCSud+8opdQkphP6jUDlqOEKoHn0DMaYLmNM1B78IXDOdJc900bemqU1faWUmmA6ob8TWC4i1SLiA24EHhw9g4iUjRq8Dqizf38UeLuIFIpIIfB2e9yMGYwlSRntd0cppSZz0mQ0xiRE5DNYYe0G7jLG7BWR24EaY8yDwF+LyHVAAugGbraX7RaRr2IdOABuN8Z0z8D3GDHcrbK26Sul1ETTSkZjzMPAw+PGfWnU718AvjDFsncBd51GGd+UN16gonfvKKXUeI57IldfoKKUUlNzXujrC1SUUmpKjgv9cMR+KbrW9JVSagLnhX5U35qllFJTcVzoDzfvBPVCrlJKTeC80I/qW7OUUmoqjgz9gFffmqWUUpNxXDKGI3G9c0cppabgwNDXF6gopdRUHBf62q2yUkpNzXmhH9EeNpVSairOC/1oQu/RV0qpKTgu9MORBLn6qkSllJqUA0M/rs07Sik1BUeF/vBbs7R5RymlJueo0B+KW2/N0pq+UkpNzlGhr90qK6XUiU0r9EXkahHZLyKHROS2Sab/rYjUishrIvJHEVk8alpSRF6xfx4cv+yZpD1sKqXUiZ00HUXEDXwXuApoBHaKyIPGmNpRs+0GNhtjBkXkU8B/AB+wpw0ZYzae4XJPqn/kVYka+kopNZnp1PS3AIeMMfXGmBhwL3D96BmMMU8YYwbtwReBijNbzOkZfj+u3rKplFKTm07olwPHRw032uOm8jHgkVHDfhGpEZEXReRdky0gIrfY89R0dHRMo0iT64/qW7OUUupEppOOMsk4M+mMIh8CNgOXjhq9yBjTLCJLgMdF5HVjzOExH2bMHcAdAJs3b570s6cjrM07Sil1QtOp6TcClaOGK4Dm8TOJyFbgn4DrjDHR4fHGmGb733rgSWDTaZT3hIZfoKKhr5RSk5tO6O8ElotItYj4gBuBMXfhiMgm4AdYgd8+anyhiGTZv5cAFwKjLwCfUcMXcnO0eUcppSZ10nQ0xiRE5DPAo4AbuMsYs1dEbgdqjDEPAl8HcoFfigjAMWPMdcBq4AciksI6wPzbuLt+zqhwNIHf68Krb81SSqlJTatKbIx5GHh43Lgvjfp96xTLPQ+sP50Cvhna2ZpSSp2Yo6rE/VHtS18ppU7EWaGvPWwqpdQJOSv0tYdNpZQ6IUeFvtWmr6GvlFJTcV7oa/OOUkpNyVGh3x9NkKc1faWUmpJjQn/4rVl5fr1lUymlpuKY0I/EUyRTRpt3lFLqBBwT+mHtYVMppU7KMQlZmpvF3q+8A7drsk5BlVJKgYNCX0S0ozWllDoJxzTvKKWUOjkNfaWUyiBizCm/qGpGiEgH0HAaH1ECdJ6h4jiBbo+xdHtMpNtkrLfq9lhsjCk92UxzLvRPl4jUGGM2p7scc4Vuj7F0e0yk22Qsp28Pbd5RSqkMoqGvlFIZxImhf0e6CzDH6PYYS7fHRLpNxnL09nBcm75SSqmpObGmr5RSagoa+koplUEcE/oicrWI7BeRQyJyW7rLkw4iUikiT4hInYjsFZG/sccXicg2ETlo/1uY7rLOJhFxi8huEXnIHq4WkR329viFiPjSXcbZIiIFInK/iOyz95Pzdf+Qz9l/L3tE5B4R8Tt5H3FE6IuIG/gucA2wBvigiKxJb6nSIgH8
nTFmNXAe8Gl7O9wG/NEYsxz4oz2cSf4GqBs1/O/AN+3t0QN8LC2lSo9vA38wxqwCNmBtl4zdP0SkHPhrYLMxZh3gBm7EwfuII0If2AIcMsbUG2NiwL3A9Wku06wzxrQYY162fw9j/UGXY22Lu+3Z7gbelZ4Szj4RqQD+BPiRPSzAFcD99iwZsz1EJAhcAtwJYIyJGWN6yeD9w+YBAiLiAbKBFhy8jzgl9MuB46OGG+1xGUtEqoBNwA5gvjGmBawDAzAvfSWbdd8C/gFI2cPFQK8xJmEPZ9K+sgToAH5sN3f9SERyyOD9wxjTBPwncAwr7PuAXTh4H3FK6E/WiX7G3osqIrnAr4BbjTGhdJcnXUTknUC7MWbX6NGTzJop+4oHOBv4njFmEzBABjXlTMa+fnE9UA0sBHKwmonHc8w+4pTQbwQqRw1XAM1pKktaiYgXK/B/box5wB7dJiJl9vQyoD1d5ZtlFwLXichRrCa/K7Bq/gX2qTxk1r7SCDQaY3bYw/djHQQydf8A2AocMcZ0GGPiwAPABTh4H3FK6O8ElttX3H1YF2IeTHOZZp3dXn0nUGeM+caoSQ8CN9m/3wT8drbLlg7GmC8YYyqMMVVY+8Tjxpg/B54A3mfPlknboxU4LiIr7VFXArVk6P5hOwacJyLZ9t/P8DZx7D7imCdyReRarFqcG7jLGPMvaS7SrBORi4BngNd5ow37i1jt+vcBi7B28huMMd1pKWSaiMhlwOeNMe8UkSVYNf8iYDfwIWNMNJ3lmy0ishHrorYPqAc+glX5y9j9Q0S+AnwA6+633cDHsdrwHbmPOCb0lVJKnZxTmneUUkpNg4a+UkplEA19pZTKIBr6SimVQTT0lVIqg2joK6VUBtHQV0qpDPL/AZ1i9XPrpWv6AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time for GD is:\t277.4945981502533\n"
     ]
    }
   ],
   "source": [
    "# Mini-batch SGD training of the logistic-regression model on a subset of the training set\n",
    "import time\n",
    "\n",
    "input_dim = 784     # features per sample (flattened 28x28 MNIST image)\n",
    "num_classes = 10    # digits 0-9\n",
    "num_train = 8000    # number of training samples used -- NOT the SGD mini-batch size (presumably the model's `batch` argument)\n",
    "model2 = LogisticRegression(input_dim, num_classes, lr=1000, batch=1000, weight_scale=4e-2, weight_decay=1.0, reg=0.0)\n",
    "\n",
    "tstart = time.time()\n",
    "model2.train(train_data[:num_train], train_labels[0][:num_train])\n",
    "tend = time.time()\n",
    "\n",
    "# Plot the loss / accuracy history recorded by the model during training\n",
    "Loss, Acc = model2.history()\n",
    "fig, (ax_loss, ax_acc) = plt.subplots(2, 1, figsize=(8, 6))\n",
    "ax_loss.plot(Loss)\n",
    "ax_loss.set_title('Loss')\n",
    "ax_acc.plot(Acc)\n",
    "ax_acc.set_title('Accuracy')\n",
    "fig.tight_layout()  # keep the second subplot's title from overlapping the first subplot's x-axis\n",
    "plt.show()\n",
    "print('Time for SGD is:\\t%s' % (tend - tstart))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize the test set with the mean/std taken from the training set, not the test set's own\n",
    "# statistics: using the test statistics would leak information about the test data.\n",
    "# A raw copy is kept in `test_data_raw` so re-running this cell does not normalize twice.\n",
    "if 'test_data_raw' not in globals():\n",
    "    test_data_raw = test_data\n",
    "test_data = (test_data_raw - mean) / (std + 1e-7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy:\t0.936734693877551\n"
     ]
    }
   ],
   "source": [
    "# Predict the test labels and report the fraction that match the ground truth\n",
    "pred_y = model2.test(test_data)\n",
    "true_y = test_labels[0]\n",
    "# Guard against silent broadcasting: mismatched shapes (e.g. (N,) vs (N, 1)) would compare\n",
    "# every prediction with every label and yield a meaningless accuracy.\n",
    "assert np.shape(pred_y) == np.shape(true_y), (np.shape(pred_y), np.shape(true_y))\n",
    "accuracy = np.mean(pred_y == true_y)\n",
    "print('accuracy:\\t%s' % accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 总结\n",
    "实验做得很不完备，但是挺惊讶的，一个简单的 Logistic Regression 就可以在测试集上达到这么高的准确率（约 93.7%），下次用 CNN 跑一下。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "torch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
